{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d4OeeljJJikY"
      },
      "source": [
        "# About Neural ODEs: Using `torchdiffeq` with `DeepChem`\n",
        "\n",
        "Author: [Anshuman Mishra](https://github.com/shivance) | [LinkedIn](https://www.linkedin.com/in/anshumon/)\n",
        "\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1x1_EexmiTk01dsAbdfopWKWbIGENiXC1?usp=sharing)\n",
        "\n",
        "\n",
        "Before getting our hands dirty with code, let us first understand a little about what Neural ODEs are."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VVgZYeQ_n3qv"
      },
      "source": [
        "# Neural ODEs and torchdiffeq\n",
        "\n",
        "NeuralODE stands for \"Neural Ordinary Differential Equation\". You heard right. Let me guess: your first impression of the name is, \"Does it have something to do with the differential equations we studied in school?\"\n",
        "\n",
        "Spot on! Here is the formal definition from the original [paper](https://arxiv.org/pdf/1806.07366.pdf):\n",
        "\n",
        "\n",
        "\n",
        "```\n",
        "Neural ODEs are a new family of deep neural network models. Instead of specifying a discrete sequence of \n",
        "hidden layers, we parameterize the derivative of the hidden state using a neural network.\n",
        "\n",
        "The output of the network is computed using a black-box differential equation solver. These are continuous-depth\n",
        "models that have constant memory cost, adapt their evaluation strategy to each input, and can explicitly trade\n",
        "numerical precision for speed.\n",
        "```\n",
        "\n",
        "\n",
        "In simple words, think of a Neural ODE as yet another type of layer, like `Linear`, `Conv2D` or multi-head attention (MHA).\n",
        "\n",
        "\n",
        "\n",
        "In this tutorial we will use [torchdiffeq](https://github.com/rtqichen/torchdiffeq), a library of ordinary differential equation (ODE) solvers implemented in PyTorch. It provides a clean API of ODE solvers for use in deep learning applications, and because the solvers are written in PyTorch, they are fully supported on the GPU.\n"
      ]
    },
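    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the \"yet another layer\" view concrete, here is a minimal sketch (our own illustration, not code from the paper) of the link between residual networks and ODE solvers: a residual update `h + dt * f(h)` is exactly one explicit Euler step for $\\frac{dh}{dt} = f(h)$. The vector field `f` is a hypothetical tiny MLP. The sketch checks that ten weight-shared residual updates with step size 0.1 match torchdiffeq's fixed-step `euler` method on the same time grid, and compares them with the default adaptive solver.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "torch.manual_seed(0)\n",
        "# Hypothetical vector field f(h): a tiny MLP shared by every 'layer'.\n",
        "mlp = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 3))\n",
        "\n",
        "def f(t, h):\n",
        "    return mlp(h)  # autonomous dynamics: t is ignored\n",
        "\n",
        "h0 = torch.randn(5, 3)\n",
        "steps, dt = 10, 0.1\n",
        "\n",
        "# ResNet view: 10 residual blocks with shared weights, each h <- h + dt * f(h).\n",
        "h = h0\n",
        "for _ in range(steps):\n",
        "    h = h + dt * f(None, h)\n",
        "\n",
        "# ODE view: the same updates are explicit Euler steps on dh/dt = f(h) over [0, 1].\n",
        "t = torch.linspace(0, 1, steps + 1)\n",
        "with torch.no_grad():\n",
        "    h_euler = odeint(f, h0, t, method='euler')[-1]\n",
        "    h_adaptive = odeint(f, h0, torch.tensor([0.0, 1.0]))[-1]  # default adaptive dopri5\n",
        "\n",
        "print(torch.allclose(h.detach(), h_euler, atol=1e-5))\n",
        "print((h_euler - h_adaptive).abs().max().item())\n",
        "```\n",
        "\n",
        "The second printed number is the discretisation error of the 10-step Euler scheme relative to the adaptive solver. Instead of fixing the number of \"layers\" in advance, a Neural ODE lets the solver choose its own steps."
      ]
    },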
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A49L9T0MrQW3"
      },
      "source": [
        "## What will you learn in this tutorial?\n",
        "\n",
        "\n",
        "\n",
        "1.   How to embed a Neural ODE in a neural network.\n",
        "2.   How to use torchdiffeq with DeepChem.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cJ_f-vywL_6G"
      },
      "source": [
        "### Installing Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BA8x2pcIWZK-",
        "outputId": "dabaa338-7105-46c9-890b-631ecdb1f18b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting torchdiffeq\n",
            "  Downloading torchdiffeq-0.2.2-py3-none-any.whl (31 kB)\n",
            "Requirement already satisfied: torch>=1.3.0 in /usr/local/lib/python3.7/dist-packages (from torchdiffeq) (1.10.0+cu111)\n",
            "Requirement already satisfied: scipy>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from torchdiffeq) (1.4.1)\n",
            "Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.7/dist-packages (from scipy>=1.4.0->torchdiffeq) (1.21.5)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.3.0->torchdiffeq) (3.10.0.2)\n",
            "Installing collected packages: torchdiffeq\n",
            "Successfully installed torchdiffeq-0.2.2\n",
            "Collecting deepchem\n",
            "  Downloading deepchem-2.6.1-py3-none-any.whl (608 kB)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from deepchem) (1.4.1)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from deepchem) (1.0.2)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from deepchem) (1.3.5)\n",
            "Collecting rdkit-pypi\n",
            "  Downloading rdkit_pypi-2021.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.6 MB)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from deepchem) (1.1.0)\n",
            "Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from deepchem) (1.21.5)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->deepchem) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->deepchem) (2018.9)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->deepchem) (1.15.0)\n",
            "Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from rdkit-pypi->deepchem) (7.1.2)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->deepchem) (3.1.0)\n",
            "Installing collected packages: rdkit-pypi, deepchem\n",
            "Successfully installed deepchem-2.6.1 rdkit-pypi-2021.9.4\n"
          ]
        }
      ],
      "source": [
        "!pip install torchdiffeq\n",
        "!pip install --pre deepchem"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gngfRPYhJhaj"
      },
      "source": [
        "### Import Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "oB_zAXmxXEsV"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "from torchdiffeq import odeint\n",
        "import math\n",
        "import numpy as np\n",
        "\n",
        "import deepchem as dc\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FxT3gkFwuZyz"
      },
      "source": [
        "Before diving into the core of this tutorial, let's get acquainted with torchdiffeq by solving the following differential equation:\n",
        "\n",
        "$ \\frac{dz(t)}{dt} = f(t) = t $\n",
        "\n",
        "with the initial condition $z(0) = 0$.\n",
        "\n",
        "Solving it by hand:\n",
        "\n",
        "$$\n",
        "z(t) = \\int t \\, dt = \\frac{t^2}{2} + C\n",
        "$$\n",
        "\n",
        "and the initial condition $z(0) = 0$ gives $C = 0$, so $z(t) = \\frac{t^2}{2}$.\n",
        "\n",
        "\n",
        "Now let's solve it numerically with `odeint`, the ODE solver from torchdiffeq."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "kWKRr4e2uW2_"
      },
      "outputs": [],
      "source": [
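        "# Right-hand side of dz/dt = f(t, z) = t.\n",
        "# odeint expects a callable f(t, z) that returns a tensor with the same shape as z.\n",
        "def f(t, z):\n",
        "  return t * torch.ones_like(z)\n",
        "\n",
        "z0 = torch.tensor([0.])        # initial condition z(0) = 0\n",
        "t = torch.linspace(0, 2, 100)  # time points at which odeint reports z(t)\n",
        "out = odeint(f, z0, t)         # solution, shape (100, 1)"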
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9P6E7qnDy7BC"
      },
      "source": [
        "Let's plot the result. It should be a parabola: $z = \\frac{t^2}{2}$ has the form $x^2 = 4ay$ of a parabola, with $a = \\frac{1}{2}$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 320
        },
        "id": "myKWP7aJzghx",
        "outputId": "3094698f-6ab5-4e3f-8cb6-965be4af1656"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAbNUlEQVR4nO3df7BcZZ3n8fcnMWSFWBiJtmwICTpQ6sjKjy5gSmu80YiR2SFODTuEYRFcU3d0ZXXYzVbBpAqmEGoYB0Ud2IFbkAK3rlxnHdFoxWEStIfdZeImYaNXQCFGA8kyiZOwwWvYYJLv/tGnM51O9+1z7z3dffr051XVdbvPc/r0882B733u9zynH0UEZmZWXLN63QEzM+ssJ3ozs4JzojczKzgnejOzgnOiNzMruNf0ugPNLFiwIJYsWZLJsX71q19xyimnZHKsXilCDFCMOBxDfhQhjixj2Lp16z9FxBubteUy0S9ZsoQtW7ZkcqxKpcLQ0FAmx+qVIsQAxYjDMeRHEeLIMgZJO1u1uXRjZlZwTvRmZgXnRG9mVnBO9GZmBedEb2ZWcG0TvaRFkr4n6WlJT0n6dJN9JOlLkrZL+qGkC+rarpX0XPK4NusAzGywjI6PsuQLS3jf37+PJV9Ywuj4aK+7lHtpplceBv5TRDwp6XXAVkkbIuLpun0+BJydPC4G/gq4WNIbgFuAMhDJe9dFxEuZRmFmA2F0fJThbw1z8NcHAdh5YCfD3xoG4Opzr+5l13Kt7Yg+Il6MiCeT578EngEWNuy2AvhyVG0CXi/pdOCDwIaI2J8k9w3A8kwjMLOBseaxNceSfM3BXx9kzWNretSj/jClG6YkLQHOB77f0LQQeKHu9a5kW6vtzY49DAwDlEolKpXKVLrW0sTERGbH6pUixADFiMMx9NbzB55vub0fY+rWuUid6CXNA/4G+OOIeDnrjkTECDACUC6XI6u7xXz3XH4UIQ7H0Duj46PM0iyOxJET2s489cy+jKlb5yLVrBtJc6gm+dGI+HqTXXYDi+pen5Fsa7XdzCy1Wm2+WZI/ec7J3P7+23vQq/6RZtaNgAeAZyLi8y12Wwd8JJl9cwlwICJeBB4FLpU0X9J84NJkm5lZas1q8wCzNZuR3x3xhdg20pRu3g1cA4xL2pZs+xPgTICIuBdYD1wGbAcOAh9N2vZL+gywOXnfrRGxP7vum9kgaFWbPxpHneRTaJvoI+J/AGqzTwCfbNG2Flg7rd6Z2cBrV5u39nxnrJnl1mS1+bmz5ro2n5ITvZnl1mS1+dXnrHbZJiUnejPLrclq88tKy7rcm/7lRG9muVSrzTfj2vzUONGbWe543ny2nOjNLHc8bz5bTvRmljueN58tJ3ozyxXX5rPnRG9mueHafGc40ZtZbrg23xlO9GaWG67Nd4YTvZnlgmvzneNEb2Y959p8ZznRm1nPuTbfWU70ZtZTo+Oj7Dyws2mba/PZcKI3s56plWxacW0+G20XHpG0FvjXwN6IeGeT9v8M1H7lvgZ4O/DGZHWpnwO/BI4AhyOinFXHzaz/tSrZgGvzWUozon8QWN6qMSL+IiLOi4jzgJuAv29YLnBp0u4kb2bHaTWdEnBtPkNtE31EPA6kXef1KuDhGfXIzAbCZNMpF5+62Ek+Q6ou99pmJ2kJ8O1mpZu6fU4GdgG/URvRS/oZ8BIQwH0RMTLJ+4eBYYBSqXTh2NhY+igmMTExwbx58zI5Vq8UIQYoRhyOIRsb92zkzmfv5NDRQye0zZ01l9XnrG67sEge4pipLGNYunTp1laVk7Y1+in4XeB/NpRt3hMRuyW9Cdgg6cfJXwgnSH4JjACUy+UYGhrKpFOVSoWsjtUrRYgBihGHY8jGdV+4rmmSn63ZPPDhB1KN5vMQx0x1K4YsZ92spKFsExG7k597gUeAizL8PDPrU/6qg+7KJNFLOhV4L/DNum2nSH
pd7TlwKfCjLD7PzPqXv+qg+9JMr3wYGAIWSNoF3ALMAYiIe5Pdfg/4u4j4Vd1bS8Ajkmqf85WI+Nvsum5m/cZfddAbbRN9RFyVYp8HqU7DrN+2A3jXdDtmZsXjrzroDd8Za2Zd49p8bzjRm1lXuDbfO070ZtZxrs33lhO9mXWca/O95URvZh3lryHuPSd6M+sYfw1xPjjRm1nH+GuI88GJ3sw6xl9DnA9O9GbWEf4a4vxwojezzHk6Zb440ZtZ5jydMl+c6M0sU55OmT9O9GaWGU+nzCcnejPLjKdT5pMTvZllxtMp88mJ3swy4emU+dU20UtaK2mvpKbLAEoaknRA0rbkcXNd23JJP5G0XdKNWXbczPLD0ynzLc2I/kFgeZt9/ntEnJc8bgWQNBu4B/gQ8A7gKknvmElnzSyfPJ0y39om+oh4HNg/jWNfBGyPiB0R8SowBqyYxnHMLMc8nTL/2q4Zm9JvSfoB8H+A1RHxFLAQeKFun13Axa0OIGkYGAYolUpUKpVMOjYxMZHZsXqlCDFAMeJwDMfbuGcjdz57Z8v2N819U8f+vXwu0ssi0T8JLI6ICUmXAd8Azp7qQSJiBBgBKJfLMTQ0lEHXoFKpkNWxeqUIMUAx4nAMx7vuC9dx6Oihpm0nzzmZz/3O5xg6N5vPauRzkd6MZ91ExMsRMZE8Xw/MkbQA2A0sqtv1jGSbmRWEp1P2hxkneklvlqTk+UXJMfcBm4GzJZ0l6SRgJbBupp9nZvng6ZT9o23pRtLDwBCwQNIu4BZgDkBE3AtcAXxC0mHgFWBlRARwWNL1wKPAbGBtUrs3sz7n6ZT9pW2ij4ir2rTfDdzdom09sH56XTOzvPJ0yv7iO2PNbEo8nbL/ONGbWWr+dsr+5ERvZqn52yn7kxO9maUyWckGPJ0yz5zozaytdiUbT6fMNyd6M2vLJZv+5kRvZm35Dtj+5kRvZpPyHbD9z4nezFryHbDF4ERvZi35DthicKI3s6Z8B2xxONGb2Ql8B2yxONGb2Qk8nbJYnOjN7Di+A7Z4nOjN7BjfAVtMbRO9pLWS9kr6UYv2qyX9UNK4pCckvauu7efJ9m2StmTZcTPLnks2xZRmRP8gsHyS9p8B742Ic4HPkCzwXWdpRJwXEeXpddHMusElm+JKs8LU45KWTNL+RN3LTVQXATezPuKSTbGpurxrm52qif7bEfHONvutBt4WEauS1z8DXgICuC8iGkf79e8dBoYBSqXShWNjYylDmNzExATz5s3L5Fi9UoQYoBhxFDWGlZtWsufQnqb7z501l9XnrGZZaVk3updaUc/FdC1dunRrq8pJ2xF9WpKWAh8D3lO3+T0RsVvSm4ANkn4cEY83e3/yS2AEoFwux9DQUCb9qlQqZHWsXilCDFCMOIoYw+j4aMskD/DAhx/I5Wi+iOeiUzKZdSPpXwH3AysiYl9te0TsTn7uBR4BLsri88wsGy7ZDIYZJ3pJZwJfB66JiGfrtp8i6XW158ClQNOZO2bWG55lMxjalm4kPQwMAQsk7QJuAeYARMS9wM3AacB/kQRwOKkTlYBHkm2vAb4SEX/bgRjMbBo8y2ZwpJl1c1Wb9lXAqibbdwDvOvEdZtZrLtkMFt8ZazaAXLIZLE70ZgNm456NLtkMGCd6swEyOj7Knc/e2bLdJZticqI3GyBrHlvDoaOHmra5ZFNcTvRmA8KzbAaXE73ZAPAsm8HmRG82ADzLZrA50ZsVnEs25kRvVmAu2Rg40ZsVmks2Bk70ZoXlko3VONGbFZBLNlbPid6sgCYr2cydNdclmwHjRG9WMO1KNqvPWe3R/IBxojcrkDQlm7yt/Wqd50RvViCeZWPNpEr0ktZK2iup6VKAqvqSpO2Sfijpgrq2ayU9lzyuzarjZnY8z7KxVtKO6B8Elk/S/iHg7OQxDPwVgKQ3UF168GKqC4PfImn+dDtrZs15lo1NJlWij4jHgf2T7LIC+HJUbQJeL+l04IPAhojYHx
EvARuY/BeGmU2DSzY2mbZrxqa0EHih7vWuZFur7SeQNEz1rwFKpRKVSiWTjk1MTGR2rF4pQgxQjDjyGEO7FaNueOsNLNy38Fi/8xjDdBQhjm7FkFWin7GIGAFGAMrlcgwNDWVy3EqlQlbH6pUixADFiCNvMYyOj3LXE3e1bF986mJuu/K247blLYbpKkIc3Yohq1k3u4FFda/PSLa12m5mGXDJxtLIKtGvAz6SzL65BDgQES8CjwKXSpqfXIS9NNlmZjPkWTaWVqrSjaSHgSFggaRdVGfSzAGIiHuB9cBlwHbgIPDRpG2/pM8Am5ND3RoRk13UNbMUPMvGpiJVoo+Iq9q0B/DJFm1rgbVT75qZNTM6Psq1j1zLkTjStN0lG2vkO2PN+khtJN8qyYNLNnYiJ3qzPjLZxVdwycaac6I36xPtLr66ZGOtONGb9YF2F19na7ZLNtaSE71ZH2g3X/6h33vISd5acqI3yznPl7eZcqI3yzHPl7cs5Oa7bszseJ4vb1nxiN4shzxf3rLkRG+WQ54vb1lyojfLGc+Xt6w50ZvliOfLWyf4YqxZTqS5+Ookb9PhEb1ZDvjiq3WSE71ZDvjiq3WSE71Zj/niq3VaqkQvabmkn0jaLunGJu13SdqWPJ6V9H/r2o7Uta3LsvNm/c4XX60b2l6MlTQbuAf4ALAL2CxpXUQ8XdsnIm6o2/8/AOfXHeKViDgvuy6bFYMvvlq3pBnRXwRsj4gdEfEqMAasmGT/q4CHs+icWVH54qt1k6rLvU6yg3QFsDwiViWvrwEujojrm+y7GNgEnBFR/S9Y0mFgG3AYuCMivtHic4aBYYBSqXTh2NjYtIOqNzExwbx58zI5Vq8UIQYoRhxZxbBy00r2HNrTsr00t8TYJdn8P9CoCOcBihFHljEsXbp0a0SUm7VlPY9+JfC1WpJPLI6I3ZLeAnxX0nhE/LTxjRExAowAlMvlGBoayqRDlUqFrI7VK0WIAYoRRxYxjI6PTprkT55zMp/7nc8xdO7MPqeVIpwHKEYc3YohTelmN7Co7vUZybZmVtJQtomI3cnPHUCF4+v3ZgPFF1+tF9Ik+s3A2ZLOknQS1WR+wuwZSW8D5gP/ULdtvqS5yfMFwLuBpxvfazYIahdfvVKUdVvb0k1EHJZ0PfAoMBtYGxFPSboV2BIRtaS/EhiL44v+bwfuk3SU6i+VO+pn65gNCl98tV5KVaOPiPXA+oZtNze8/tMm73sCOHcG/TPre+2mUYLvfLXO8p2xZh2UZiTvO1+t05zozTqo3XfY+OKrdYMTvVmHpPkOG198tW5wojfrAE+jtDzxwiNmGfN32FjeeERvliFPo7Q88ojeLCOeRml55RG9WQY8jdLyzCN6sxlKM5L3xVfrJY/ozWYg7Uje0yitlzyiN5smj+StX3hEbzYNHslbP/GI3myKPJK3fuMRvdkUbNyz0SN56zse0ZulNDo+yp/9+M84ytGW+3gkb3mUakQvabmkn0jaLunGJu3XSfqFpG3JY1Vd27WSnkse12bZebNuqdXkJ0vyHslbXrUd0UuaDdwDfADYBWyWtK7JSlFfjYjrG977BuAWoAwEsDV570uZ9N6sC1yTt36XZkR/EbA9InZExKvAGLAi5fE/CGyIiP1Jct8ALJ9eV826z7NrrAjS1OgXAi/Uvd4FXNxkv9+X9NvAs8ANEfFCi/cubPYhkoaBYYBSqUSlUknRtfYmJiYyO1avFCEG6L84Nu7Z2LYmP4tZ3PDWG1i4b2HfxNZv56GVIsTRrRiyuhj7LeDhiDgk6Y+Ah4D3TeUAETECjACUy+UYGhrKpGOVSoWsjtUrRYgB+iuO0fFR7nrirrY1+X4s1/TTeZhMEeLoVgxpSje7gUV1r89Ith0TEfsi4lDy8n7gwrTvNcubWk3eSwBaUaRJ9JuBsyWdJekkYCWwrn4HSafXvbwceCZ5/ihwqaT5kuYDlybbzHLJNXkroralm4g4LOl6qgl6NrA2Ip
6SdCuwJSLWAZ+SdDlwGNgPXJe8d7+kz1D9ZQFwa0Ts70AcZjOWZnbNLGZ5JG99J1WNPiLWA+sbtt1c9/wm4KYW710LrJ1BH806Lu1I/oa33uAkb33Hd8bawJvKPPmF+5pOGjPLNX/XjQ001+RtEHhEbwPLd7zaoPCI3gaSR/I2SDyit4HjkbwNGid6Gxij46N8+jufZt8r+ybdr1/veDVrxYneBkKtVDPZ3a7gkbwVkxO9FV6aUg14JG/F5YuxVmhpLrqCR/JWbB7RW2F5JG9W5URvhZP2oivAaa89jS9+6ItO8lZoTvRWKFO56Oo58jYonOitMFyqMWvOid763lRKNb7oaoPIid76WtpSDXgkb4PLid76VtpSDfiiqw22VIle0nLgi1RXmLo/Iu5oaP+PwCqqK0z9Avh3EbEzaTsCjCe7Ph8Rl2fUdxtQUy3V+KKrDbq2iV7SbOAe4APALmCzpHUR8XTdbv8bKEfEQUmfAD4LXJm0vRIR52XcbxtQLtWYTV2aO2MvArZHxI6IeBUYA1bU7xAR34uI2v95m4Azsu2m2T+XatIk+dNee5qTvFlCETH5DtIVwPKIWJW8vga4OCKub7H/3cA/RsRtyevDwDaqZZ07IuIbLd43DAwDlEqlC8fGxqYXUYOJiQnmzZuXybF6pQgxwPTj2LhnI3/53F/y8pGX2+47i1nc9LabWFZaNp0utlWEc1GEGKAYcWQZw9KlS7dGRLlZW6YXYyX9W6AMvLdu8+KI2C3pLcB3JY1HxE8b3xsRI8AIQLlcjqGhoUz6VKlUyOpYvVKEGGDqcUylFg/dKdUU4VwUIQYoRhzdiiFN6WY3sKju9RnJtuNIWgasAS6PiEO17RGxO/m5A6gA58+gvzYgarX4tEnepRqz1tKM6DcDZ0s6i2qCXwn8Yf0Oks4H7qNa4tlbt30+cDAiDklaALyb6oVas5amMm3Ss2rM2mub6CPisKTrgUepTq9cGxFPSboV2BIR64C/AOYB/00S/PM0yrcD90k6SvWvhzsaZuuYHZPHUo1ZEaSq0UfEemB9w7ab6543vfIVEU8A586kg1Z8U03w4BugzKbCd8ZazzjBm3WHE7113XQSvGvxZtPnRG9ds3HPRq747BVTSvDgWrzZTDnRW8dNZwRf41KN2cw50VvHOMGb5YMTvWXOCd4sX5zoLTNO8Gb55ERvM+YEb5ZvTvQ2bU7wZv3Bid5SGx0fZc1ja9h5YCdCBJN/xXUzTvBm3edEb201G7lPNcmf9trT+Pjij3Pblbdl3T0za8OJ3k6Qxci9pn4EX6lUsuukmaXmRG/HZDFyr3GJxiw/nOgHUP2IfbZmcySOzHjkXuMEb5Y/TvQDYLJSTG1xj5kmeSd4s/xyoi+AZiP0ViP1LEbtALM0i6NxlMWnLub299/uBG+WY6kSvaTlwBeprjB1f0Tc0dA+F/gycCGwD7gyIn6etN0EfAw4AnwqIh7NrPcFd1wCf/z4BN4qkddG6FmN1Bt55G7Wf9omekmzgXuADwC7gM2S1jUsCfgx4KWI+A1JK4E/B66U9A6qa8z+JvAvgY2SzolIsRjogKstjn3w1weBExN4pxJ5I4/czfpfmhH9RcD2iNgBIGkMWAHUJ/oVwJ8mz78G3K3q4rErgLGIOAT8TNL25Hj/kE33i2vNY2uOJfle8MjdrDjSJPqFwAt1r3cBF7faJ1lM/ABwWrJ9U8N7Fzb7EEnDwDBAqVTKbM71xMREX87ffv7A8135nFrpZxazOMpRSnNLrDprFctKy2Afmf7b9eu5qOcY8qMIcXQrhtxcjI2IEWAEoFwux9DQUCbHrVQqZHWsbjpz25nsPLAz8+P2shTTr+einmPIjyLE0a0Y0iT63cCiutdnJNua7bNL0muAU6lelE3zXmvi9vffflyNPq1aIm+8aOsau9ngSpPoNwNnSzqLapJeCfxhwz7rgGup1t6vAL4bESFpHfAVSZ+nejH2bOB/ZdX5Iqsl5FbTJp
3IzSyttok+qblfDzxKdXrl2oh4StKtwJaIWAc8APzX5GLrfqq/DEj2+2uqF24PA5/0jJv0rj736mPfEdPvf6KaWe+kqtFHxHpgfcO2m+ue/z/g37R47+3A7TPoo5mZzcCsXnfAzMw6y4nezKzgnOjNzArOid7MrOAU0dnvSpkOSb8AsrpbaAHwTxkdq1eKEAMUIw7HkB9FiCPLGBZHxBubNeQy0WdJ0paIKPe6HzNRhBigGHE4hvwoQhzdisGlGzOzgnOiNzMruEFI9CO97kAGihADFCMOx5AfRYijKzEUvkZvZjboBmFEb2Y20JzozcwKrnCJXtIbJG2Q9Fzyc36L/Y5I2pY81nW7n81IWi7pJ5K2S7qxSftcSV9N2r8vaUn3ezm5FDFcJ+kXdf/2q3rRz8lIWitpr6QftWiXpC8lMf5Q0gXd7mMaKeIYknSg7lzc3Gy/XpK0SNL3JD0t6SlJn26yT67PR8oYOnsuIqJQD+CzwI3J8xuBP2+x30Sv+9rQn9nAT4G3ACcBPwDe0bDPvwfuTZ6vBL7a635PI4brgLt73dc2cfw2cAHwoxbtlwHfAQRcAny/132eZhxDwLd73c82MZwOXJA8fx3wbJP/pnJ9PlLG0NFzUbgRPdUFyR9Knj8EfLiHfZmKY4uwR8SrQG0R9nr1sX0NeH+yCHtepIkh9yLicarrKrSyAvhyVG0CXi/p9O70Lr0UceReRLwYEU8mz38JPMOJ607n+nykjKGjipjoSxHxYvL8H4FSi/3+haQtkjZJysMvg2aLsDf+x3DcIuxAbRH2vEgTA8DvJ39if03SoibteZc2zn7wW5J+IOk7kn6z152ZTFKqPB/4fkNT35yPSWKADp6L3CwOPhWSNgJvbtK0pv5FRISkVvNHF0fEbklvAb4raTwifpp1X+0E3wIejohDkv6I6l8o7+txnwbVk1T/P5iQdBnwDarLfeaOpHnA3wB/HBEv97o/09Emho6ei74c0UfEsoh4Z5PHN4E9tT/bkp97Wxxjd/JzB1Ch+lu2l6ayCDsNi7DnRdsYImJfRBxKXt4PXNilvmWpEIveR8TLETGRPF8PzJG0oMfdOoGkOVQT5GhEfL3JLrk/H+1i6PS56MtE30ZtoXKSn99s3EHSfElzk+cLgHdTXde2l44twi7pJKoXWxtnA9XHdmwR9i72sZ22MTTUTi+nWq/sN+uAjySzPS4BDtSVC/uGpDfXrvFIuohqPsjTwIGkfw8Az0TE51vsluvzkSaGTp+LvizdtHEH8NeSPkb1q47/AEBSGfh4RKwC3g7cJ+ko1X/QOyKip4k+ZrAIe16kjOFTki6nulj8fqqzcHJF0sNUZ0EskLQLuAWYAxAR91JdP/kyYDtwEPhob3o6uRRxXAF8QtJh4BVgZc4GDlAdhF0DjEvalmz7E+BM6JvzkSaGjp4LfwWCmVnBFbF0Y2ZmdZzozcwKzonezKzgnOjNzArOid7MrOCc6M3MCs6J3sys4P4/dpxhfCphj4kAAAAASUVORK5CYII=",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          },
          "output_type": "display_data"
        }
      ],
      "source": [
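        "plt.plot(t, out, 'go--')\n",
        "# Use the current axes; calling plt.axes() again is deprecated and may create a new, empty axes.\n",
        "plt.gca().set_aspect('equal', 'datalim')\n",
        "plt.grid()\n",
        "plt.show()"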
      ]
    },
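    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since the exact solution is known here, we can also check the solver numerically rather than only by eye. The short sketch below (an addition to the original tutorial) recomputes the solution with torchdiffeq's default adaptive solver and reports the largest deviation from $\\frac{t^2}{2}$.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "def f(t, z):\n",
        "    return t * torch.ones_like(z)  # dz/dt = t\n",
        "\n",
        "z0 = torch.tensor([0.0])\n",
        "t = torch.linspace(0, 2, 100)\n",
        "out = odeint(f, z0, t)  # default method: adaptive dopri5\n",
        "\n",
        "exact = (t ** 2 / 2).unsqueeze(-1)  # analytic solution, shape (100, 1)\n",
        "max_err = (out - exact).abs().max().item()\n",
        "print(f'max abs error: {max_err:.2e}')\n",
        "```\n",
        "\n",
        "The error should be at the level of single-precision round-off, because the solver's polynomial stages can represent this quadratic solution exactly."
      ]
    },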
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RgkBAR46yUwM"
      },
      "source": [
        "# What is a Neural Differential Equation?\n",
        "\n",
        "A neural differential equation is a differential equation that uses a neural network to parameterize the vector field. The canonical example is a neural ordinary differential equation:\n",
        "\n",
        "$y(0) = y_0$\n",
        "\n",
        "$\\frac{dy}{dt} (t) = f_\\theta(t,y(t)) $\n",
        "\n",
        "Here $\\theta$ represents some vector of learnt parameters, $f_\\theta : \\mathbb{R} \\times \\mathbb{R}^{d_1 \\times \\dots \\times d_k} \\to \\mathbb{R}^{d_1 \\times \\dots \\times d_k}$ is any standard neural architecture, and $y : [0, T] \\to \\mathbb{R}^{d_1 \\times \\dots \\times d_k}$ is the solution. For many applications $f_\\theta$ will just be a simple feedforward network. Here $d_1, \\dots, d_k$ are the sizes of the dimensions of the state $y(t)$.\n",
        "\n",
        "\n",
        "[Reference](https://arxiv.org/pdf/2202.02435.pdf)\n",
        "\n",
        "The central idea now is to use a differential equation solver as part of a learnt differentiable computation graph (the sort of computation graph ubiquitous in deep learning).\n",
        "\n",
        "\n",
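        "As a simple example, suppose we observe some picture $y_0 \\in \\mathbb{R}^{3 \\times 32 \\times 32}$ (RGB and 32x32 pixels), and wish to classify it as a picture of a cat or as a picture of a dog. We may take $y(0) = y_0$ as the initial condition of the neural ODE and evolve it until some time $T$; an affine map followed by a softmax is then applied to $y(T)$, so that the output can be read as the pair (P(cat), P(dog)).\n",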
        " \n",
        " "
      ]
    },
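    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The classifier just described can be sketched in a few lines. This is a minimal sketch under our own assumptions, not code from the reference: the vector field `ConvField` is a hypothetical two-layer convolutional network, the final time is $T = 1$, and a fixed-step `rk4` solver with step size 0.25 is used only to keep the run time predictable.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "class ConvField(nn.Module):\n",
        "    # Hypothetical f_theta(t, y): maps (B, 3, 32, 32) -> (B, 3, 32, 32); t is unused.\n",
        "    def __init__(self, channels=3, hidden=16):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Conv2d(channels, hidden, kernel_size=3, padding=1),\n",
        "            nn.Tanh(),\n",
        "            nn.Conv2d(hidden, channels, kernel_size=3, padding=1),\n",
        "        )\n",
        "\n",
        "    def forward(self, t, y):\n",
        "        return self.net(y)\n",
        "\n",
        "class CatDogODE(nn.Module):\n",
        "    def __init__(self, T=1.0):\n",
        "        super().__init__()\n",
        "        self.field = ConvField()\n",
        "        self.register_buffer('times', torch.tensor([0.0, T]))  # integrate over [0, T]\n",
        "        self.head = nn.Linear(3 * 32 * 32, 2)  # affine map to two logits\n",
        "\n",
        "    def forward(self, y0):\n",
        "        # Evolve y(0) = y0 to y(T) and keep only the final state.\n",
        "        yT = odeint(self.field, y0, self.times, method='rk4', options=dict(step_size=0.25))[-1]\n",
        "        return torch.softmax(self.head(yT.flatten(1)), dim=-1)\n",
        "\n",
        "torch.manual_seed(0)\n",
        "model = CatDogODE()\n",
        "probs = model(torch.randn(4, 3, 32, 32))  # a batch of 4 random 'images'\n",
        "print(probs.shape)  # torch.Size([4, 2]): (P(cat), P(dog)) per image\n",
        "```\n",
        "\n",
        "Training would proceed as for any classifier (for example, minimising the negative log-likelihood of the predicted probabilities), with gradients flowing back through the solver."
      ]
    },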
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S7ClNiTArtnd"
      },
      "source": [
        "With torchdiffeq we can also solve systems of higher-order differential equations. Below is an example: a small linear system whose first two equations describe a spring-mass-damper (the $p \\dot{x}$ term damps the motion when $p < 0$), extended with an equation for the third derivative:\n",
        "\n",
        "$\\dot{x}= \\frac{dx}{dt} $\n",
        "\n",
        "$\\ddot{x} = -(k/m) x + p \\dot{x} $\n",
        "\n",
        "$\\dddot{x} = g \\dot{x} - r \\ddot{x}$\n",
        "\n",
        "with the initial state $x(0) = 1$, $\\dot{x}(0) = 0$, $\\ddot{x}(0) = 1$. In matrix form:\n",
        "\n",
        "\n",
        "\n",
        "$$\n",
        "\\left[ \\begin{array}{c} \\dot{x} \\\\\\ \\ddot{x} \\\\\\ \\dddot{x} \\end{array} \\right] = \\left[\\begin{array}{ccc} 0 & 1 & 0\\\\\\ -\\frac{k}{m} & p & 0\\\\\\ 0 & g & -r \\end{array} \\right]\n",
        "\\left[ \\begin{array}{c} x \\\\\\ \\dot{x}\\\\\\ \\ddot{x} \\end{array} \\right]\n",
        "$$"
      ]
    },
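    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a linear system $\\dot{y} = M y$ the exact solution is $y(t) = e^{tM} y(0)$, which gives a handy sanity check for the solver. The sketch below is an addition to the tutorial: the parameter values ($k/m = 1$, $p = -0.5$, $g = 3$, $r = 2$) are illustrative assumptions, and it relies on `torch.linalg.matrix_exp`, which is available in recent PyTorch releases. Because $y = (x, \\dot{x}, \\ddot{x})$ is treated as a column vector, the right-hand side is `M @ y`.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "km, p, g, r = 1.0, -0.5, 3.0, 2.0  # illustrative parameter values\n",
        "M = torch.tensor([[0.0, 1.0, 0.0],\n",
        "                  [-km, p, 0.0],\n",
        "                  [0.0, g, -r]])\n",
        "\n",
        "def rhs(t, y):\n",
        "    return M @ y  # y = (x, x_dot, x_ddot), as in the matrix equation\n",
        "\n",
        "y0 = torch.tensor([1.0, 0.0, 1.0])\n",
        "t = torch.linspace(0, 5, 50)\n",
        "numeric = odeint(rhs, y0, t)  # shape (50, 3)\n",
        "exact = torch.stack([torch.linalg.matrix_exp(ti * M) @ y0 for ti in t])\n",
        "max_err = (numeric - exact).abs().max().item()\n",
        "print(max_err)\n",
        "```\n"
      ]
    },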
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LEUerc0E7Uv0"
      },
      "source": [
        "The right-hand side (here, a matrix with the parameters $k/m$, $p$, $g$ and $r$) may be regarded as a particular differentiable computation graph. The parameters may be fitted by setting up a loss between the trajectories of the model and the observed trajectories in the data, backpropagating through the model, and applying stochastic gradient descent.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 313
        },
        "id": "PymaV9V_tzzu",
        "outputId": "c0528bdb-4cc0-4e58-df8d-c3d633b3b352"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAEDCAYAAAAcI05xAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASUklEQVR4nO3de7BdZX3G8e8vV3IBgg1kMDAkIsNlMgqcCCqtEtA2BQc6zjiDVQRF0VFatc60Oupoq9Oq7XRgRoqgWJQiGYpUULwxkAioCByqXBIi1HBJwlVIw0lITJpf/1j7aMSEs/c+Z+31Lvh+ZvasfVl7v8+Ck+e8ebP23pGZSJLKNanpAJKk52dRS1LhLGpJKpxFLUmFs6glqXAWtSQVrraijoivRsTjEXF3F/u+LyLuioifR8TNEXFE5/4/iojlETESEV+sK6sklSzqOo86Il4HjABfz8xFY+y7V2Zu7Fw/BXh/Zi6NiFnAUcAiYFFmnlNLWEkqWG0z6sy8EXhq5/si4uCI+H5EDEfETRFxWGffjTvtNgvIzv2bMvNmYEtdOSWpdFMGPN5FwPsy876IOBb4N+AEgIj4APA3wLTR+yRJAyzqiJgNvBb4z4gYvXv66JXMPB84PyL+EvgEcMagsklSyQY5o54EbMjMI8fYbxlwwQDySFIrDOz0vM469JqIeAtAVF7ZuX7ITrueDNw3qFySVLo6z/q4HDgemAs8BnwKuIFqtrw/MBVYlpn/EBHnAW8AtgFPA+dk5j2d13kA2Itq7XoD8KeZubKW0JJUoNqKWpI0MXxnoiQVrpZ/TJw7d24uWLCgr+du2rSJWbNmTWygAWp7fmj/MbQ9P7T/GMzfu+Hh4Sczc99dPpiZE34ZGhrKfi1fvrzv55ag7fkz238Mbc+f2f5jMH/vgNtzN53q0ockFc6ilqTCWdSSVDiLWpIKZ1FLUuEsakkqnEUtSYWzqCVpIlxzDXzhC7W8tEUtSRPh29+Gc8+t5aUtakmaCFu3wvTpY+/XB4takiaCRS1JhbOoJalwFrUkFc6ilqTCWdSSVDiLWpIKZ1FLUuEsakkqnEUtSYWzqCWpcBa1JBXOopakwlnUklSw7dthxw6LWpKKtXVrtbWoJalQFrUkFW7LlmprUUtSoZxRS1LhLGpJKpxFLUmFs6glqXAWtSQVzqKWpMJZ1JJUOItakgrXdFFHxB4RcWtE/CIi7omIv68liSS1Vc1FPaWbCMAJmTkSEVOBmyPie5l5Sy2JJKltmi7qzExgpHNzaueStaSRpDaquaij6uExdoqYDAwDLwfOz8y/28U+ZwNnA8ybN29o2bJlfQUaGRlh9uzZfT23BG3PD+0/hrbnh/Yfw4st/4HLlnHwhRdy07XX8n8zZ/Y15pIlS4Yzc/EuH8zMri/AHGA5sOj59hsaGsp+LV++vO/nlqDt+TPbfwxtz5/Z/mN40eX/zGcyIXPr1r7HBG7P3XRqT2d9ZOaGTlEv7etXhiS9EI0ufUydWsvLd3PWx74RMadzfQbwRuDeWtJIUhtt2QIzZkBELS/fzVkf+wNf66xTTwKuyMzv1JJGktro2Weroq5JN2d93AkcVVsCSWq7movadyZK0ng9+yzssUdtL29RS9J4ja5R18SilqTxculDkgpnUUtS4VyjlqTCuUYtSYVz6UOSCmdRS1LhXKOWpMK5Ri1JBct06UOSirZtG+zYYVFLUrGefbbaukYtSYXasqXaOqOWpEKNzqgtakkqlEUtSYVzjVqSCucatSQVzqUPSSqcRS1JhXONWpIK5xq1JBXOpQ9JKtzmzdXWopakQo0W9axZtQ1hUUvSeGzaBJMnw7RptQ1hUUvSeGzaVM2mI2obwqKWpPEYLeoaWdSSNB4WtSQVbtMmmDmz1iEsakkaD2fUklQ4i1qSCmdRS1LhLGpJKpxFLUmF27zZopakYmWWMaOOiAMjYn
lErIyIeyLig7UmkqS22LoVduyovaindLHPduAjmXlHROwJDEfEdZm5stZkklS6TZuqbdMz6sx8JDPv6Fx/BlgFzK81lSS1wWhR1/zOxMjM7neOWADcCCzKzI3Peexs4GyAefPmDS1btqyvQCMjI8yePbuv55ag7fmh/cfQ9vzQ/mN4seSf+eCDHHPmmaz8xCd4/MQTxzXmkiVLhjNz8S4fzMyuLsBsYBh481j7Dg0NZb+WL1/e93NL0Pb8me0/hrbnz2z/Mbxo8t92WyZkXn31uMcEbs/ddGpXZ31ExFTgm8BlmXnVuH5tSNILRSlr1BERwMXAqsz811rTSFKblFLUwHHA6cAJEfHzzuWkWlNJUhsMqKjHPD0vM28G6vuOGUlqq5GRarvnnrUO4zsTJalfGzsnv+21V63DWNSS1K9nnqm2zqglqVAbN8Iee8DUqbUOY1FLUr82bqx9Ng0WtST175lnal+fBotakvq3caNFLUlFe+YZlz4kqWjOqCWpcM6oJalwzqglqXDOqCWpYNu3w7PPOqOWpGKNvn3copakQo1+IJNLH5JUKGfUklQ4Z9SSVLgNG6rt3nvXPpRFLUn9ePrparvPPrUPZVFLUj8sakkqnEUtSYXbsKH69vGav90FLGpJ6s/TTw9kNg0WtST1x6KWpMJZ1JJUOItakgpnUUtS4Z5+GubMGchQFrUk9WrbNhgZcUYtScUa/ZwPi1qSCjXAdyWCRS1JvXvyyWo7d+5AhrOoJalXTzxRbffddyDDWdSS1CuLWpIKZ1FLUuEefxxmz4YZMwYynEUtSb164omBzabBopak3pVW1BHx1Yh4PCLuHkQgSSpeaUUNXAIsrTmHJLVHaUWdmTcCTw0giySVL3PgRR2ZOfZOEQuA72TmoufZ52zgbIB58+YNLVu2rK9AIyMjzJ49u6/nlqDt+aH9x9D2/ND+Y3gh55+8eTN/cvLJ/M9738vDp502YWMuWbJkODMX7/LBzBzzAiwA7u5m38xkaGgo+7V8+fK+n1uCtufPbP8xtD1/ZvuP4QWdf/XqTMi89NIJHRO4PXfTqZ71IUm9WL++2r70pQMb0qKWpF6UWNQRcTnwU+DQiFgbEWfVH0uSCrVuXbWdP39gQ04Za4fMfOsggkhSK6xfX719fM89BzakSx+S1Iv16we67AEWtST1Zt06i1qSirZ+/UDXp8GilqTuZVZFvf/+Ax3Wopakbj32GGzdCgcdNNBhLWpJ6taaNdV24cKBDmtRS1K3Rot6wYKBDmtRS1K3Hnig2lrUklSoNWtgv/1g1qyBDmtRS1K31qwZ+Po0WNSS1D2LWpIKtnVrtUb98pcPfGiLWpK6cf/9sGMHHH74wIe2qCWpG6tWVVuLWpIKNVrUhx468KEtaknqxqpV1VvHZ84c+NAWtSR1Y+XKRpY9wKKWpLFt3VoV9ZFHNjK8RS1JY7nrLti2DYaGGhneopaksQwPV9ujj25keItaksYyPAz77NPIuxLBopaksf3sZ7B4MUQ0MrxFLUnP59e/hjvvhNe/vrEIFrUkPZ+bbqq2FrUkFWrFCpgxA171qsYiWNSStDuZcO218LrXwfTpjcWwqCVpd+69t/rUvFNPbTSGRS1Ju/Otb1XbU05pNIZFLUm7kglf/zq89rUwf36jUSxqSdqVW26plj7OOqvpJBa1JO3SeefBnnvCW97SdBKLWpKea8bDD8MVV8D731+VdcMsakl6joO/9KXqCwI+/OGmowAwpekAklSUK69k7k9+Ap//PMyb13QawBm1JP3O6tXwrnex8bDD4EMfajrNb1nUkgTwy1/CiSfCtGnc8+lPw7RpTSf6LYta0otbJnzjG3DMMfCb38ANN7C1kCWPUV0VdUQsjYjVEXF/RHy07lCSVLuRkd8V9NveBocdVp07/YpXNJ3sD4z5j4kRMRk4H3gjsBa4LSKuycyVdYeTpL5lwpYtVSE/9RQ89FB1Wb26+iKAW2+tHl+wAC6+GN7xDp
hS5vkVkZnPv0PEa4BPZ+afdW5/DCAz/2l3z1m8eHHefvvtvaeZPp3cto1o6FsUJkJmtjo/tP8Y2p4f2n8Mv5d/dx2z8/GN0UN9htj9Y1OnVp+GN2tWdZk+vbrssQdMn86GDRuYM2dO72MeeSSce25fcSNiODMX7+qxbn59zAce3un2WuDYXQxyNnA2wLx581ixYkXPQf84gkmTJlHD/7LBiWh3fmj/MbQ9P7wgjiFHi7ibXzjj/aX03FJ+7tij/z07t2PHDmLzZmJk5A9easfkycycMYPNe+/NtjlzyB5m2SNr13J/H903lgmb52fmRcBFUM2ojz/++N5fZMsWVqxYQV/PLUTb80P7j6Ht+aH9x9Ca/Dt2VEsj69bBgw/Cr37FpOFhtl93HTMffhgefbRaEvnUp+Cgg8Z8uTnAATXE7Kao1wEH7nT7gM59ktRukybBXntVl8MP/+3dt65YwfFz58KXvwwXXli9nfz88+GMM5qJ2cU+twGHRMTCiJgGnAZcU28sSWrYokXVBzOtXl19DdeZZ8JnP9tIlDGLOjO3A+cAPwBWAVdk5j11B5OkIhx0EFx3HZx+Onzyk3DBBQOP0NUadWZ+F/huzVkkqUxTpsAll8Cvf129tXzx4oF+2a3vTJSkbkyaBJdeCvvtB+95D2zfPrihBzaSJLXdS15SnSf9i19Ub5IZEItaknrx5jfDscfC5z4H27YNZEiLWpJ6EQEf/zg88ABceeVAhrSoJalXJ58MCxcObPnDopakXk2aBO98J1x/PaxZU/9wtY8gSS9Eb397tb3qqtqHsqglqR8LF1afXX311bUPZVFLUr9OPRV+/GN48slah7GoJalfJ51UfQJfDR9tujOLWpL6NTRUffHAj35U6zAWtST1a+pUOO44Z9SSVLTXvx7uvrv6wKaaWNSSNB7Hdr6Z8I47ahvCopak8Tj66Go7PFzbEBa1JI3HPvvAy15mUUtS0YaGLGpJKtpRR1Wf+bFxYy0vb1FL0niNfoP5vffW8vIWtSSN12hRr1pVy8tb1JI0XgcfXL35xaKWpEJNmQKHHOLShyQV7fDDa5tRT6nlVSXpxWbp0uqc6szqexUnkEUtSRPh3e+uLjVw6UOSCmdRS1LhLGpJKpxFLUmFs6glqXAWtSQVzqKWpMJZ1JJUuMjMiX/RiCeAB/t8+lzgyQmMM2htzw/tP4a254f2H4P5e3dQZu67qwdqKerxiIjbM3Nx0zn61fb80P5jaHt+aP8xmH9iufQhSYWzqCWpcCUW9UVNBxintueH9h9D2/ND+4/B/BOouDVqSdLvK3FGLUnaiUUtSYUrpqgjYmlErI6I+yPio03n6VVEHBgRyyNiZUTcExEfbDpTPyJickT8d0R8p+ks/YiIORFxZUTcGxGrIuI1TWfqRUR8uPPzc3dEXB4RezSdaSwR8dWIeDwi7t7pvpdExHURcV9nu0+TGZ/PbvL/c+dn6M6I+K+ImNNkxiKKOiImA+cDfw4cAbw1Io5oNlXPtgMfycwjgFcDH2jhMQB8EKjni98G4zzg+5l5GPBKWnQsETEf+GtgcWYuAiYDpzWbqiuXAEufc99Hgesz8xDg+s7tUl3CH+a/DliUma8Afgl8bNChdlZEUQPHAPdn5q8y8zfAMuDUhjP1JDMfycw7OtefoSqI+c2m6k1EHACcDHyl6Sz9iIi9gdcBFwNk5m8yc0OzqXo2BZgREVOAmcD6hvOMKTNvBJ56zt2nAl/rXP8a8BcDDdWDXeXPzB9m5vbOzVuAAwYebCelFPV84OGdbq+lZSW3s4hYABwF/KzZJD07F/hbYEfTQfq0EHgC+PfO8s1XImJW06G6lZnrgH8BHgIeAf43M3/YbKq+zcvMRzrXHwXmNRlmnN4FfK/JAKUU9QtGRMwGvgl8KDM3Np2nWxHxJuDxzBxuOss4TAGOBi7IzKOATZT9V+7f01nHPZXqF85LgVkR8fZmU41fVucAt/I84Ij4ONWy5mVN5i
ilqNcBB+50+4DOfa0SEVOpSvqyzLyq6Tw9Og44JSIeoFp6OiEi/qPZSD1bC6zNzNG/yVxJVdxt8QZgTWY+kZnbgKuA1zacqV+PRcT+AJ3t4w3n6VlEnAm8CXhbNvyGk1KK+jbgkIhYGBHTqP4B5ZqGM/UkIoJqbXRVZv5r03l6lZkfy8wDMnMB1X//GzKzVbO5zHwUeDgiDu3cdSKwssFIvXoIeHVEzOz8PJ1Ii/4x9DmuAc7oXD8DuLrBLD2LiKVUy4CnZObmpvMUUdSdRftzgB9Q/WBekZn3NJuqZ8cBp1PNRH/euZzUdKgXob8CLouIO4EjgX9sOE/XOn8TuBK4A7iL6s9nUW9l3pWIuBz4KXBoRKyNiLOAzwFvjIj7qP6m8LkmMz6f3eT/IrAncF3nz/KXGs3oW8glqWxFzKglSbtnUUtS4SxqSSqcRS1JhbOoJalwFrUkFc6ilqTC/T+IQ/y7N/D9VgAAAABJRU5ErkJggg==",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "class SystemOfEquations:\n",
        "  # Linear system dy/dt = M y with y = (x, x_dot, x_ddot) and M as in the matrix equation above.\n",
        "\n",
        "  def __init__(self, km, p, g, r):\n",
        "    self.mat = torch.tensor([[0., 1., 0.], [-km, p, 0.], [0., g, -r]])\n",
        "\n",
        "  def solve(self, t, x0, dx0, ddx0):\n",
        "    y0 = torch.cat([x0, dx0, ddx0])  # initial state vector, shape (3,)\n",
        "    return odeint(self.func, y0, t)  # trajectory, shape (len(t), 3)\n",
        "\n",
        "  def func(self, t, y):\n",
        "    # Matrix-vector product M y (y @ M would compute M^T y instead).\n",
        "    return self.mat @ y\n",
        "\n",
        "\n",
        "x0 = torch.tensor([1.])    # x(0)\n",
        "dx0 = torch.tensor([0.])   # x_dot(0)\n",
        "ddx0 = torch.tensor([1.])  # x_ddot(0)\n",
        "\n",
        "t = torch.linspace(0, 4*np.pi, 1000)\n",
        "solver = SystemOfEquations(km=1, p=6, g=3, r=2)\n",
        "out = solver.solve(t, x0, dx0, ddx0)\n",
        "\n",
        "plt.plot(t, out, 'r')\n",
        "plt.grid()\n",
        "plt.show()"
      ]
    },
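    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The fitting procedure described above can be sketched on a reduced version of this system. The sketch is ours, not part of the original tutorial: it keeps only the spring-mass-damper part ($y = (x, \\dot{x})$ with $\\ddot{x} = -(k/m) x + p \\dot{x}$), generates an 'observed' trajectory from assumed true values $k/m = 2$, $p = -0.3$, and fits both parameters with Adam on a mean-squared trajectory loss, backpropagating through `odeint`.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "class SpringDamper(nn.Module):\n",
        "    # dy/dt for y = (x, x_dot): x_ddot = -(k/m) x + p x_dot, with learnable k/m and p.\n",
        "    def __init__(self, km, p):\n",
        "        super().__init__()\n",
        "        self.km = nn.Parameter(torch.tensor(km))\n",
        "        self.p = nn.Parameter(torch.tensor(p))\n",
        "\n",
        "    def forward(self, t, y):\n",
        "        x, v = y[..., 0], y[..., 1]\n",
        "        return torch.stack([v, -self.km * x + self.p * v], dim=-1)\n",
        "\n",
        "t = torch.linspace(0, 5, 50)\n",
        "y0 = torch.tensor([1.0, 0.0])\n",
        "with torch.no_grad():\n",
        "    observed = odeint(SpringDamper(2.0, -0.3), y0, t)  # synthetic 'data'\n",
        "\n",
        "model = SpringDamper(1.5, 0.0)  # initial guess\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.05)\n",
        "losses = []\n",
        "for step in range(100):\n",
        "    optimizer.zero_grad()\n",
        "    loss = ((odeint(model, y0, t) - observed) ** 2).mean()\n",
        "    loss.backward()  # gradients flow through the solver's internal steps\n",
        "    optimizer.step()\n",
        "    losses.append(loss.item())\n",
        "\n",
        "print(f'loss {losses[0]:.4f} -> {losses[-1]:.4f}')\n",
        "print(f'k/m = {model.km.item():.3f}, p = {model.p.item():.3f}')\n",
        "```\n",
        "\n",
        "Here `loss.backward()` differentiates through every step the solver took; `odeint_adjoint`, which the model later in this notebook uses, is the memory-saving alternative."
      ]
    },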
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S-8b2Nlf7uQ2"
      },
      "source": [
        "This is precisely the same procedure as for the more general neural ODEs introduced earlier. At first glance, the neural differential equation (NDE) approach of ‘putting a neural network in a differential equation’ may seem unusual, but it is actually in line with standard practice: all that has changed is the parameterisation of the vector field, from a matrix with a few learnable entries to a neural network."
      ]
    },
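    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make \"changing the parameterisation\" concrete, the sketch below (our illustration; the class names are hypothetical) solves the same initial value problem with two vector fields: a linear one, $f(t, y) = M y$ with a learnable matrix $M$, and a small MLP. Both are `nn.Module`s with the `forward(t, y)` signature that `odeint` expects, so the solver call does not change.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "class LinearField(nn.Module):\n",
        "    # f(t, y) = M y with a learnable matrix M (the classical parameterisation)\n",
        "    def __init__(self, dim=3):\n",
        "        super().__init__()\n",
        "        self.M = nn.Parameter(0.1 * torch.randn(dim, dim))\n",
        "\n",
        "    def forward(self, t, y):\n",
        "        return y @ self.M.T  # row-wise M y for y of shape (batch, dim)\n",
        "\n",
        "class MLPField(nn.Module):\n",
        "    # f(t, y) = MLP(y): the neural ODE parameterisation\n",
        "    def __init__(self, dim=3, hidden=32):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))\n",
        "\n",
        "    def forward(self, t, y):\n",
        "        return self.net(y)\n",
        "\n",
        "torch.manual_seed(0)\n",
        "y0 = torch.randn(4, 3)  # batch of 4 initial states\n",
        "t = torch.linspace(0, 1, 10)\n",
        "shapes = {}\n",
        "for field in (LinearField(), MLPField()):\n",
        "    traj = odeint(field, y0, t)  # identical call for both fields\n",
        "    shapes[type(field).__name__] = tuple(traj.shape)\n",
        "print(shapes)  # both trajectories have shape (10, 4, 3)\n",
        "```\n"
      ]
    },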
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UFrs-sEVeUle"
      },
      "source": [
        "# Model\n",
        "\n",
        "### Let us look at how to embed an ODE solver in a neural network.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "is0BkdicaiDu"
      },
      "outputs": [],
      "source": [
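        "from torchdiffeq import odeint_adjoint as odeadj\n",
        "\n",
        "class f(nn.Module):\n",
        "  # Vector field f_theta(t, x) of the neural ODE: an MLP mapping a state of size dim to its\n",
        "  # time derivative (same size). t is unused, so the dynamics are autonomous. odeint_adjoint\n",
        "  # requires the vector field to be an nn.Module so that it can find its parameters.\n",
        "  def __init__(self, dim):\n",
        "    super(f, self).__init__()\n",
        "    self.model = nn.Sequential(\n",
        "        nn.Linear(dim, 124),\n",
        "        nn.ReLU(),\n",
        "        nn.Linear(124, 124),\n",
        "        nn.ReLU(),\n",
        "        nn.Linear(124, dim),\n",
        "        nn.Tanh()  # bounds each component of dx/dt to (-1, 1)\n",
        "    )\n",
        "\n",
        "  def forward(self, t, x):\n",
        "    return self.model(x)"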
      ]
    },
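    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The solver calls the vector field many times per forward pass, and how many times depends on the requested accuracy; this is the \"trade numerical precision for speed\" property quoted from the paper. The sketch below is an addition: `CountingField` is a hypothetical helper, and the tolerances are illustrative. It counts the number of function evaluations (NFE) of a small MLP vector field under a loose and a tight relative tolerance.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "class CountingField(nn.Module):\n",
        "    # Hypothetical helper: an MLP vector field that counts how often the solver calls it.\n",
        "    def __init__(self, dim=2):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(nn.Linear(dim, 64), nn.Tanh(), nn.Linear(64, dim))\n",
        "        self.nfe = 0\n",
        "\n",
        "    def forward(self, t, x):\n",
        "        self.nfe += 1\n",
        "        return self.net(x)\n",
        "\n",
        "torch.manual_seed(0)\n",
        "field = CountingField()\n",
        "x0 = torch.randn(16, 2)\n",
        "times = torch.tensor([0.0, 1.0])\n",
        "\n",
        "nfe = {}\n",
        "for rtol in (1e-3, 1e-7):\n",
        "    field.nfe = 0\n",
        "    with torch.no_grad():\n",
        "        odeint(field, x0, times, rtol=rtol, atol=rtol * 1e-2)\n",
        "    nfe[rtol] = field.nfe\n",
        "print(nfe)  # the tighter tolerance needs more evaluations\n",
        "```\n"
      ]
    },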
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "97VYZesy1-en"
      },
      "source": [
        "The vector field `f` defined above is passed to `ODEBlock`, another `nn.Module` (see the code cell below), so that the dynamics $\\frac{dy}{dt} (t) = f_\\theta(t,y(t))$ become part of a neural network.\n",
        "`ODEBlock` treats its input `x` as the initial value $y(0)$ of the differential equation, integrates over the fixed interval $[0, 1]$, and returns the solution at $t = 1$ as the layer's output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "NcTIcscv2Ai7"
      },
      "outputs": [],
      "source": [
        "class ODEBlock(nn.Module):\n",
        "  \n",
        "  # This is ODEBlock. Think of it as a wrapper over ODE Solver , so as to easily connect it with our neurons !\n",
        "\n",
        "  def __init__(self, f):\n",
        "    super(ODEBlock, self).__init__()\n",
        "    self.f = f\n",
        "    self.integration_time = torch.Tensor([0,1]).float()\n",
        "\n",
        "  def forward(self, x):\n",
        "    self.integration_time = self.integration_time.type_as(x)\n",
        "    out = odeadj(\n",
        "        self.f,\n",
        "        x,\n",
        "        self.integration_time\n",
        "    )\n",
        "\n",
        "    return out[1]\n",
        "\n",
        "\n",
        "class ODENet(nn.Module):\n",
        "  \n",
        "  #This is our main neural network that uses ODEBlock within a sequential module\n",
        "\n",
        "  def __init__(self, in_dim, mid_dim, out_dim):\n",
        "    super(ODENet, self).__init__()\n",
        "    fx = f(dim=mid_dim)\n",
        "    self.fc1 = nn.Linear(in_dim, mid_dim)\n",
        "    self.relu1 = nn.ReLU(inplace=True)\n",
        "    self.norm1 = nn.BatchNorm1d(mid_dim)\n",
        "    self.ode_block = ODEBlock(fx)\n",
        "    self.dropout = nn.Dropout(0.4)\n",
        "    self.norm2 = nn.BatchNorm1d(mid_dim)\n",
        "    self.fc2 = nn.Linear(mid_dim, out_dim)\n",
        "\n",
        "  def forward(self, x):\n",
        "    batch_size = x.shape[0]\n",
        "    x = x.view(batch_size, -1)\n",
        "\n",
        "    out = self.fc1(x)\n",
        "    out = self.relu1(out)\n",
        "    out = self.norm1(out)\n",
        "    out = self.ode_block(out)\n",
        "    out = self.norm2(out)\n",
        "    out = self.dropout(out)\n",
        "    out = self.fc2(out)\n",
        "\n",
        "    return out"
      ]
    },
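    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`ODEBlock` integrates with `odeint_adjoint` rather than `odeint`. Both should give the same gradients up to solver tolerance; the adjoint version obtains them by solving a second ODE backwards in time instead of backpropagating through every internal solver step, which is what gives Neural ODEs the constant memory cost mentioned in the paper's abstract. The sketch below checks this on a tiny, untrained vector field and, like `ODEBlock`, keeps only index `[1]` of the solution (the state at $t = 1$). The class name `TinyField`, the sizes, the float64 dtype and the tolerances are illustrative assumptions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint, odeint_adjoint\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Hypothetical vector field; odeint_adjoint needs an nn.Module so that it can\n",
        "# find the parameters whose gradients it must compute.\n",
        "class TinyField(nn.Module):\n",
        "  def __init__(self, dim):\n",
        "    super().__init__()\n",
        "    self.net = nn.Sequential(nn.Linear(dim, 16), nn.Tanh(), nn.Linear(16, dim))\n",
        "\n",
        "  def forward(self, t, x):\n",
        "    return self.net(x)\n",
        "\n",
        "func = TinyField(dim=3).double()\n",
        "x_start = torch.randn(4, 3, dtype=torch.float64)\n",
        "t_span = torch.tensor([0.0, 1.0], dtype=torch.float64)\n",
        "\n",
        "def param_grads(solver):\n",
        "  # Gradient of a scalar loss on y(1) with respect to every parameter of func\n",
        "  func.zero_grad()\n",
        "  y1 = solver(func, x_start, t_span, rtol=1e-9, atol=1e-9)[1]  # [1] = state at t = 1, as in ODEBlock\n",
        "  y1.pow(2).sum().backward()\n",
        "  return [p.grad.clone() for p in func.parameters()]\n",
        "\n",
        "direct = param_grads(odeint)  # backpropagate through the solver steps\n",
        "adjoint = param_grads(odeint_adjoint)  # adjoint method, as used by ODEBlock\n",
        "\n",
        "for g_direct, g_adjoint in zip(direct, adjoint):\n",
        "  print(torch.max(torch.abs(g_direct - g_adjoint)).item())"
      ]
    },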
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mIkICdmIG-ic"
      },
      "source": [
        "As mentioned before , Neural ODE Networks acts similar (has advantages though) to other neural networks , so we can solve any problem with them as the existing models do. We are gonna reuse the training process mentioned in [this](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Creating_Models_with_TensorFlow_and_PyTorch.ipynb) deepchem tutorial.\n",
        "\n",
        "So Rather than demonstrating how to use NeuralODE model with a normal dataset, we shall use the **Delaney solubility dataset** provided under **deepchem** . Our model will learn to predict the solubilities of molecules based on their extended-connectivity fingerprints (ECFPs) . For performance metrics we use [pearson_r2_score](https://deepchem.readthedocs.io/en/latest/api_reference/metrics.html#deepchem.metrics.pearson_r2_score) . Here loss is computed directly from the model's output"
      ]
    },
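    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "DeepChem's `pearson_r2_score` is the square of the Pearson correlation coefficient between the true and predicted values. The NumPy sketch below uses a hypothetical helper, `pearson_r2`, rather than DeepChem's own implementation, to show one consequence: predictions that are an exact linear function of the targets, even with the wrong scale and offset, still score 1. A high score therefore indicates strong linear agreement, not necessarily small absolute errors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def pearson_r2(y_true, y_pred):\n",
        "  # Square of the Pearson correlation coefficient between targets and predictions\n",
        "  r = np.corrcoef(y_true, y_pred)[0, 1]\n",
        "  return r ** 2\n",
        "\n",
        "y_true = np.array([-2.0, -1.0, 0.5, 1.0, 3.0])\n",
        "print(pearson_r2(y_true, 2 * y_true + 1))  # exact linear relation: 1 up to floating-point rounding\n",
        "print(pearson_r2(y_true, np.array([0.0, 1.0, -1.0, 2.0, 0.5])))  # weaker relation, lower score"
      ]
    },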
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "Cl7op_CoHBed"
      },
      "outputs": [],
      "source": [
        "tasks, dataset, transformers = dc.molnet.load_delaney(featurizer='ECFP', splitter='random')\n",
        "train_set, valid_set, test_set = dataset\n",
        "metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3JeD7uJVhhLx"
      },
      "source": [
        "## Time to Train\n",
        "\n",
        "We train our model for 50 epochs, with L2 as Loss Function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1fpHI3eQnTWO",
        "outputId": "46b3c583-3da2-484c-e0a5-cf0066f6df7b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training set score :  {'pearson_r2_score': 0.9708644701066554}\n",
            "Test set score :  {'pearson_r2_score': 0.7104556551957734}\n"
          ]
        }
      ],
      "source": [
        "# Like mentioned before one can use GPUs with PyTorch and torchdiffeq\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "model = ODENet(in_dim=1024, mid_dim=1000, out_dim=1).to(device)\n",
        "model = dc.models.TorchModel(model, dc.models.losses.L2Loss())\n",
        "\n",
        "model.fit(train_set, nb_epoch=50)\n",
        "\n",
        "print('Training set score : ', model.evaluate(train_set,[metric]))\n",
        "print('Test set score : ', model.evaluate(test_set,[metric]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oYXu5gq9Ax9r"
      },
      "source": [
        "Neural ODEs are invertible neural nets [Reference](https://proceedings.mlr.press/v119/zhang20h.html#:~:text=Neural%20ODEs%20and%20i%2DResNets,.approximate%20any%20continuous%20invertible%20mapping.)\n",
        "Invertible neural networks have been a significant thread of research in the ICML community for several years. Such transformations can offer a range of unique benefits: \n",
        "\n",
        "\n",
        "\n",
        "*   They preserve information, allowing perfect reconstruction (up to numerical limits) and obviating the need to store hidden activations in memory for backpropagation.  \n",
        "*    They are often designed to track the changes in probability density that applying the transformation induces (as in normalizing flows). \n",
        "* Like autoregressive models, [normalizing flows](https://arxiv.org/pdf/2006.00104.pdf) can be powerful generative models which allow exact likelihood computations; with the right architecture, they can also allow for much cheaper sampling than autoregressive models. \n",
        "\n",
        "While many researchers are aware of these topics and intrigued by several high-profile papers, few are familiar enough with the technical details to easily follow new developments and contribute. Many may also be unaware of the wide range of applications of invertible neural networks, beyond generative modelling and variational inference."
      ]
    },
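    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reconstruction property in the first bullet can be illustrated with `torchdiffeq.odeint`: integrating a vector field forward from $t = 0$ to $t = 1$ and then backward from $t = 1$ to $t = 0$ recovers the starting point up to solver error. This is a sketch under stated assumptions: an untrained, randomly initialised MLP stands in for a trained vector field, the tolerances and float64 dtype are arbitrary choices, and backward integration relies on `odeint` accepting a strictly decreasing time tensor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchdiffeq import odeint\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Untrained MLP standing in for a trained f_theta(t, z); it ignores t.\n",
        "net = nn.Sequential(nn.Linear(3, 32), nn.Tanh(), nn.Linear(32, 3)).double()\n",
        "\n",
        "def vector_field(t, z):\n",
        "  return net(z)\n",
        "\n",
        "z0 = torch.randn(5, 3, dtype=torch.float64)\n",
        "forward_t = torch.tensor([0.0, 1.0], dtype=torch.float64)\n",
        "backward_t = torch.tensor([1.0, 0.0], dtype=torch.float64)  # decreasing times: integrate backwards\n",
        "\n",
        "with torch.no_grad():\n",
        "  z1 = odeint(vector_field, z0, forward_t, rtol=1e-10, atol=1e-10)[1]  # forward map z(0) -> z(1)\n",
        "  z0_rec = odeint(vector_field, z1, backward_t, rtol=1e-10, atol=1e-10)[1]  # inverse map z(1) -> z(0)\n",
        "\n",
        "print(torch.max(torch.abs(z0_rec - z0)).item())  # reconstruction error, limited by solver tolerance"
      ]
    },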
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7PIEFIAiKvRP"
      },
      "source": [
        "# Congratulations! Time to join the Community!\n",
        "\n",
        "Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:\n",
        "\n",
        "## Star DeepChem on [GitHub](https://github.com/deepchem/deepchem)\n",
        "This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.\n",
        "\n",
        "## Join the DeepChem Discord\n",
        "The DeepChem [Discord](https://discord.gg/cGzwCdrUqS) hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "About nODE : Using Torchdiffeq in Deepchem",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
